from torch.utils.data import DataLoader
from utils.dataset import TestDataset
import albumentations as A
from albumentations import pytorch as AT
import numpy as np
import pandas as pd
from utils.model import *
import ttach as tta
from tqdm import tqdm
import os
import json
import time

# ---- Inference configuration ----
input_size = 224  # side length of the square center crop fed to the models
nc = 137  # number of classes passed to the model constructors
batch_size = 64
nw = 8  # number of DataLoader worker processes
test_csv = '../Dataset/test_clean.csv'  # missing ['a2411.jpg']
test_path = '../Dataset/test/'

model1_name = 'resnest50d'
model2_name = 'tf_efficientnetv2_s_in21ft1k'

# ---- Output location ----
save_dir = 'pred-csv'
# save_csv keeps a trailing separator so the file name can be appended directly.
save_csv = os.path.join(save_dir, 'ress50_effv2s') + os.sep
saveFileName = f'{save_csv}resnest50-effv2s.csv'
# makedirs creates save_dir and the nested save_csv in a single call, and
# exist_ok=True avoids the check-then-create race of exists() + mkdir().
os.makedirs(save_csv, exist_ok=True)

# Test-time preprocessing: resize to 256/224 of the crop size, take a center
# crop of input_size x input_size, normalize with the library's default
# statistics, then convert to a tensor.
_resize_side = int(input_size * (256 / 224))
_test_pipeline = [
    A.Resize(_resize_side, _resize_side),
    A.CenterCrop(input_size, input_size),
    A.Normalize(),
    AT.ToTensorV2(p=1.0),
]
albu_transform = {'test': A.Compose(_test_pipeline)}

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Device properties can only be queried when CUDA is actually available;
    # calling this unconditionally crashes on CPU-only machines.
    cuda_info = torch.cuda.get_device_properties(0)
    print("using {} {} {}MB.".format(device, cuda_info.name, cuda_info.total_memory / 1024 ** 2))
else:
    cuda_info = None
    print("using {}.".format(device))

# Build the test dataset/loader. No shuffle argument is passed: predictions are
# later joined to the rows of test_csv by position, so the loader order must
# stay the same as the CSV order.
test_dataset = TestDataset(test_csv, test_path, transform=albu_transform['test'])
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=nw)
print('Test total {} images'.format(len(test_dataset)))
# read num_to_class
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
# {'0': 0, '1': 1, ...}
# Read inside a context manager so the file handle is closed right after
# parsing instead of staying open for the rest of the script.
with open(json_path, "r") as json_file:
    num_to_class = json.load(json_file)

# predict
# Build the two ensemble members; pretrained=False because the trained weights
# are loaded from the local checkpoints below.
model1 = timm_model(model1_name, pretrained=False, num_classes=nc).to(device)
model2 = timm_model(model2_name, pretrained=False, num_classes=nc).to(device)

# Checkpoint paths, in the same order as model1 / model2.
model_path_list = [
    'resnest50d-seed233/epoch28-val_acc-0.9006.pth',
    'tf_efficientnetv2_s_in21ft1k/epoch28-val_acc-0.912.pth',
]

print('--------------------------------------')
print('predict-starting')
print('--------------------------------------')

# map_location=device remaps the stored tensors onto the selected device, so
# checkpoints saved on a GPU still load when the script falls back to CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model1.load_state_dict(torch.load(model_path_list[0], map_location=device))
model2.load_state_dict(torch.load(model_path_list[1], map_location=device))
print(f'load weight {model_path_list}')
print(f'save name: {saveFileName}')
# Presumably gives the prints above time to flush before the tqdm progress bar
# starts drawing -- TODO confirm.
time.sleep(0.1)

# Make sure the model is in eval mode.
# Some modules like Dropout or BatchNorm affect if the model is in training mode.
model1.eval()
model2.eval()
# Wrap each model for test-time augmentation with ttach's d4 transform set.
tta_model1 = tta.ClassificationTTAWrapper(model1, tta.aliases.d4_transform())
tta_model2 = tta.ClassificationTTAWrapper(model2, tta.aliases.d4_transform())

# Predicted class indices for the whole test set, in loader order.
predictions = []
# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    # Move the batch to the device once and reuse it for both models instead
    # of performing the host-to-device copy separately for each model.
    imgs = batch.to(device)
    with torch.no_grad():
        logits1 = tta_model1(imgs)
        logits2 = tta_model2(imgs)

        # Weighted blend of the two TTA-wrapped models' outputs
        # (0.4 for model1, 0.6 for model2).
        logits = 0.4 * logits1 + 0.6 * logits2

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

# Map each predicted class index to the label stored in class_indices.json.
# JSON object keys are strings, hence the str(i) lookup.
preds = [num_to_class[str(i)] for i in predictions]

test_data = pd.read_csv(test_csv)
# Predictions are attached to the CSV rows by position (default integer index).
test_data['category_id'] = pd.Series(preds)
pred_csv = pd.concat([test_data['image_id'], test_data['category_id']], axis=1)
# a2411.jpg is absent from test_csv, so a row for it is added by hand with a
# fixed category_id of 1 (the variable name suggests the file is really a GIF
# -- TODO confirm).
only_one_gif_format = pd.Series({'image_id': 'a2411.jpg', 'category_id': 1})  # append
# DataFrame.append was removed in pandas 2.0; pd.concat with a one-row frame
# is the supported equivalent. infer_objects() restores the column dtypes of
# the transposed Series, which would otherwise all be object.
pred_csv = pd.concat(
    [pred_csv, only_one_gif_format.to_frame().T.infer_objects()],
    ignore_index=True,
)
pred_csv.to_csv(saveFileName, index=False)
print(f"{saveFileName}predict done!")
